import abc
import jax.numpy as np
import jax

import scalevi.utils.utils as utils


class ScaleTransform(abc.ABC):
    """Abstract transform acting on the diagonal of a square matrix.

    Subclasses provide ``forward_diag_transform`` and
    ``inverse_diag_transform``, which map a vector of diagonal entries.
    ``forward`` and ``inverse`` apply those maps to the diagonal of a
    matrix and keep its strictly-lower-triangular part as is.
    """

    def __init__(self):
        """Nothing to initialise; the base class holds no state."""

    @abc.abstractmethod
    @utils.batch1D_class
    def forward_diag_transform(self, diag):
        """Map the diagonal entries used by ``forward``; must be overridden."""

    @abc.abstractmethod
    @utils.batch1D_class
    def inverse_diag_transform(self, diag):
        """Map the diagonal entries used by ``inverse``; must be overridden."""

    # NOTE(review): ``utils.batch2D_class`` presumably lets these methods
    # accept stacked matrices while the body sees one 2-D matrix at a time;
    # confirm against the helper's definition.
    @utils.batch2D_class
    def forward(self, scale_tril):
        """Apply ``forward_diag_transform`` to the diagonal of ``scale_tril``.

        The result is the strictly lower triangle of ``scale_tril`` plus a
        diagonal matrix built from the transformed diagonal, so every entry
        above the diagonal is zero in the output.
        """
        assert scale_tril.ndim == 2
        below_diag = np.tril(scale_tril, -1)
        mapped_diag = self.forward_diag_transform(np.diag(scale_tril))
        return below_diag + np.diag(mapped_diag)

    @utils.batch2D_class
    def inverse(self, scale_tril):
        """Apply ``inverse_diag_transform`` to the diagonal of ``scale_tril``.

        The result is the strictly lower triangle of ``scale_tril`` plus a
        diagonal matrix built from the transformed diagonal, so every entry
        above the diagonal is zero in the output.
        """
        assert scale_tril.ndim == 2
        below_diag = np.tril(scale_tril, -1)
        mapped_diag = self.inverse_diag_transform(np.diag(scale_tril))
        return below_diag + np.diag(mapped_diag)


class ExpScaleTransform(ScaleTransform):
    """Diagonal transform using the exponential and its inverse, the log."""

    def __init__(self):
        # Zero-argument super() dispatches to the base-class initialiser.
        # The one-argument form ``super(ExpScaleTransform)`` only builds an
        # unbound proxy, so calling ``__init__`` on it never reached the base.
        super().__init__()

    def forward_diag_transform(self, diag):
        """Return ``exp(diag)`` elementwise."""
        return np.exp(diag)

    def inverse_diag_transform(self, diag):
        """Return ``log(diag)`` elementwise, undoing ``forward_diag_transform``.

        Non-positive entries fall outside the domain of the logarithm.
        """
        return np.log(diag)
    

class LogAddExpScaleTransform(ScaleTransform):
    """Diagonal transform using ``log(exp(x) + 1)`` and its inverse."""

    def __init__(self):
        # Zero-argument super() dispatches to the base-class initialiser.
        # The one-argument form ``super(LogAddExpScaleTransform)`` only builds
        # an unbound proxy, so calling ``__init__`` on it never reached the
        # base.
        super().__init__()

    def forward_diag_transform(self, diag):
        """Return ``log(exp(diag) + exp(0))`` elementwise."""
        return np.logaddexp(diag, 0)

    def inverse_diag_transform(self, diag):
        """Return ``log(exp(diag) - exp(0))`` elementwise.

        Computed as ``diag + log1p(-exp(-diag))``, which is algebraically
        the same quantity. The result is ``-inf`` at ``diag == 0`` and
        ``nan`` for ``diag < 0``.
        """
        return diag + np.log1p(-np.exp(-diag))


class ProximalScaleTransform(ScaleTransform):
    """Diagonal transform ``x -> (x + sqrt(x**2 + 4*gamma)) / 2`` and inverse.

    ``gamma`` is stored on the instance and used by both directions.
    """

    def __init__(self, gamma):
        # Zero-argument super() dispatches to the base-class initialiser.
        # The one-argument form ``super(ProximalScaleTransform)`` only builds
        # an unbound proxy, so calling ``__init__`` on it never reached the
        # base.
        super().__init__()
        self.gamma = gamma

    def forward_diag_transform(self, diag):
        """Return ``diag + 0.5 * (sqrt(diag**2 + 4*gamma) - diag)`` elementwise.

        Algebraically this equals ``0.5 * (diag + sqrt(diag**2 + 4*gamma))``.
        """
        return diag + 0.5*((diag**2 + 4*self.gamma)**0.5 - diag)

    def inverse_diag_transform(self, diag):
        """Return ``(diag**2 - gamma) / diag`` elementwise.

        This is the algebraic inverse of ``forward_diag_transform``; it
        divides by ``diag``, so zero entries are outside its domain.
        """
        return (diag**2 - self.gamma)/diag
